__init__.py 12.6 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
# pylint: disable=doc-string-missing
G
guru4elephant 已提交
15

M
MRXLT 已提交
16 17
import paddle_serving_client
import os
18 19 20
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
D
dongdaxiang 已提交
21 22
import numpy as np
import time
23
import sys
G
guru4elephant 已提交
24

G
guru4elephant 已提交
25 26 27
# Type codes used in the model config's feed_type / fetch_type fields:
# 0 marks integer data (decoded as int64), 1 marks float data (float32).
int_type, float_type = 0, 1

M
MRXLT 已提交
28

G
guru4elephant 已提交
29 30 31
class SDKConfig(object):
    """Builds the predictor SDK description (an ``sdk.SDKConf`` proto).

    Server variants are registered with ``add_server_variant`` and turned into
    a single ``general_model`` predictor description by ``gen_desc``.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        # Parallel lists: entry i of each list describes registered variant i.
        self.tag_list = []
        self.cluster_list = []
        self.variant_weight_list = []

    def add_server_variant(self, tag, cluster, variant_weight):
        """Register one server variant.

        Args:
            tag: value written to ``VariantConf.tag`` for this variant.
            cluster: list of endpoints (e.g. ``["127.0.0.1:9292"]``). A single
                endpoint string is also accepted and wrapped in a list.
            variant_weight: routing weight of this variant; stored as a string
                because ``gen_desc`` joins all weights with "|".
        """
        if isinstance(cluster, str):
            # A bare string would otherwise be joined character by character
            # in gen_desc ("list://1,2,7,...").
            cluster = [cluster]
        self.tag_list.append(tag)
        self.cluster_list.append(cluster)
        self.variant_weight_list.append(str(variant_weight))

    def gen_desc(self):
        """Build and return an ``sdk.SDKConf`` describing all registered variants.

        The description is rebuilt from scratch on every call (and also stored
        on ``self.sdk_desc``), so calling this more than once does not append
        duplicate predictors to the configuration.
        """
        sdk_desc = sdk.SDKConf()

        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
            str(weight) for weight in self.variant_weight_list)

        for tag, cluster in zip(self.tag_list, self.cluster_list):
            variant_desc = sdk.VariantConf()
            variant_desc.tag = tag
            variant_desc.naming_conf.cluster = "list://{}".format(
                ",".join(cluster))
            predictor_desc.variants.extend([variant_desc])

        sdk_desc.predictors.extend([predictor_desc])

        # Accessing a message field returns a mutable reference into sdk_desc,
        # so these aliases write straight into the description.
        default_conf = sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection_conf = default_conf.connection_conf
        connection_conf.connect_timeout_ms = 2000
        connection_conf.rpc_timeout_ms = 20000
        connection_conf.connect_retry_count = 2
        connection_conf.max_connection_per_host = 100
        connection_conf.hedge_request_timeout_ms = -1
        connection_conf.hedge_fetch_retry_count = 2
        connection_conf.connection_type = "pooled"

        naming_conf = default_conf.naming_conf
        naming_conf.cluster_filter_strategy = "Default"
        naming_conf.load_balance_strategy = "la"

        rpc_parameter = default_conf.rpc_parameter
        rpc_parameter.compress_type = 0
        rpc_parameter.package_size = 20
        rpc_parameter.protocol = "baidu_std"
        rpc_parameter.max_channel_per_request = 3

        self.sdk_desc = sdk_desc
        return sdk_desc
G
guru4elephant 已提交
76

G
guru4elephant 已提交
77 78 79 80 81 82

class Client(object):
    """Python client for a Paddle Serving general-model service.

    Typical use: ``load_client_config`` -> ``connect`` (optionally preceded by
    ``add_variant``) -> ``predict`` ... -> ``release``.
    """

    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
        self.result_handle_ = None
        self.feed_shapes_ = {}
        self.feed_types_ = {}
        self.feed_names_to_idx_ = {}
        # The tables below are filled by load_client_config; they are created
        # here too so the attributes always exist.
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}
        self.rpath()
        self.pid = os.getpid()
        self.predictor_sdk_ = None
        self.producers = []
        self.consumer = None

    def rpath(self):
        """Set the RPATH of the bundled ``serving_client.so`` to the package ``lib`` dir.

        Best effort: runs the external ``patchelf`` tool and ignores its exit
        status. If the tool cannot be started at all, this is silently skipped.
        """
        import subprocess
        package_dir = os.path.dirname(paddle_serving_client.__file__)
        client_path = os.path.join(package_dir, 'serving_client.so')
        lib_path = os.path.join(package_dir, 'lib')
        try:
            # Argument list instead of a shell string: paths with spaces or
            # shell metacharacters are passed through verbatim.
            subprocess.call(['patchelf', '--set-rpath', lib_path, client_path])
        except OSError:
            # patchelf missing / not executable: keep the previous best-effort
            # behaviour and carry on.
            pass

    def load_client_config(self, path):
        """Load the text-format client model config at ``path``.

        Creates the native ``PredictorClient``/``PredictorRes`` handles and
        records feed/fetch names, types and shapes, plus which variables are
        LoD tensors. All per-model tables are rebuilt from scratch.
        """
        from .serving_client import PredictorClient
        from .serving_client import PredictorRes
        model_conf = m_config.GeneralModelConfig()
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        self.result_handle_ = PredictorRes()
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        if "FLAGS_max_body_size" not in os.environ:
            # Default to 512 * 1024 * 1024 unless the caller already set it.
            os.environ["FLAGS_max_body_size"] = str(512 * 1024 * 1024)
        read_env_flags = ["profile_client", "profile_server", "max_body_size"]
        # --tryfromenv lets gflags read these flags from FLAGS_* env vars.
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])

        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]
        # Reset every per-model table so reloading a config leaves no stale
        # entries behind.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            name = var.alias_name
            self.feed_names_to_idx_[name] = i
            self.feed_types_[name] = var.feed_type
            self.feed_shapes_[name] = var.shape
            if var.is_lod_tensor:
                self.lod_tensor_set.add(name)
            else:
                # Element count of the declared shape; shape_check compares
                # non-numpy feed values against it.
                counter = 1
                for dim in var.shape:
                    counter *= dim
                self.feed_tensor_len[name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            name = var.alias_name
            self.fetch_names_to_idx_[name] = i
            self.fetch_names_to_type_[name] = var.fetch_type
            if var.is_lod_tensor:
                self.lod_tensor_set.add(name)

    def add_variant(self, tag, cluster, variant_weight):
        """Register a server variant (lazily creating the SDKConfig)."""
        if self.predictor_sdk_ is None:
            self.predictor_sdk_ = SDKConfig()
        self.predictor_sdk_.add_server_variant(tag, cluster,
                                               str(variant_weight))

    def connect(self, endpoints=None):
        """Create the native predictor from the registered variants.

        Args:
            endpoints: list of server endpoints. Used to register a default
                variant (weight 100) when ``add_variant`` was never called;
                otherwise ignored, with a printed notice.

        Raises:
            SystemExit: if neither ``endpoints`` nor ``add_variant`` was used.
        """
        if endpoints is None:
            if self.predictor_sdk_ is None:
                raise SystemExit(
                    "You must set the endpoints parameter or use add_variant function to create a variant."
                )
        else:
            if self.predictor_sdk_ is None:
                self.add_variant('default_tag_{}'.format(id(self)), endpoints,
                                 100)
            else:
                print(
                    "parameter endpoints({}) will not take effect, because you use the add_variant function.".
                    format(endpoints))
        sdk_desc = self.predictor_sdk_.gen_desc()
        self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
        ))

    def get_feed_names(self):
        """Return the feed alias names from the loaded config."""
        return self.feed_names_

    def get_fetch_names(self):
        """Return the fetch alias names from the loaded config."""
        return self.fetch_names_

    def shape_check(self, feed, key):
        """Check that ``feed[key]`` has the element count declared for ``key``.

        LoD-tensor feeds are variable length and are not checked.

        Raises:
            SystemExit: if the length does not match.
        """
        if key in self.lod_tensor_set:
            return
        if len(feed[key]) != self.feed_tensor_len[key]:
            raise SystemExit("The shape of feed tensor {} not match.".format(
                key))

    def predict(self, feed=None, fetch=None, need_variant_tag=False):
        """Run a (batched) prediction.

        Args:
            feed: dict ``{feed_name: value}`` for one sample, or a list of such
                dicts for a batch. All samples must use the same feed names.
                Values are numpy arrays (flattened before sending) or flat
                lists.
            fetch: one fetch name or a list of fetch names.
            need_variant_tag: if True, return ``[result, variant_tag]``.

        Returns:
            None if the backend ``batch_predict`` call returns -1. Otherwise,
            when one engine answered, a dict mapping each fetch name to a numpy
            array (plus ``"<name>.lod"`` entries for LoD outputs). When several
            engines answered, ``{engine_name: such a dict}``.

        Raises:
            ValueError: missing/invalid ``feed`` or ``fetch``, unknown feed
                name, samples with differing feed names, or no requested fetch
                name present in the model config.
            SystemExit: a non-numpy feed value has the wrong length.
        """
        if feed is None or fetch is None:
            raise ValueError("You should specify feed and fetch for prediction")

        if isinstance(fetch, str):
            fetch_list = [fetch]
        elif isinstance(fetch, list):
            fetch_list = fetch
        else:
            raise ValueError("Fetch only accepts string and list of string")

        if isinstance(feed, dict):
            feed_batch = [feed]
        elif isinstance(feed, list):
            feed_batch = feed
        else:
            raise ValueError("Feed only accepts dict and list of dict")

        # Requested fetch names the model actually exposes, in request order.
        fetch_names = [key for key in fetch_list if key in self.fetch_names_]
        if not fetch_names:
            raise ValueError(
                "Fetch names should not be empty or out of saved fetch list.")

        (int_slot_batch, int_feed_names, int_shape, float_slot_batch,
         float_feed_names, float_shape) = self._pack_feed_batch(feed_batch)

        result_batch = self.result_handle_
        res = self.client_handle_.batch_predict(
            float_slot_batch, float_feed_names, float_shape, int_slot_batch,
            int_feed_names, int_shape, fetch_names, result_batch, self.pid)

        if res == -1:
            return None

        model_engine_names = result_batch.get_engine_names()
        multi_result_map = [
            self._read_engine_result(result_batch, mi, fetch_names)
            for mi in range(len(model_engine_names))
        ]

        if len(model_engine_names) == 1:
            # Single engine: return its result map directly.
            ret = multi_result_map[0]
        else:
            # Several engines: {engine_name: result_map}.
            ret = dict(zip(model_engine_names, multi_result_map))

        # When using an A/B test, the variant tag is returned as well.
        if need_variant_tag:
            return [ret, self.result_handle_.variant_tag()]
        return ret

    def _pack_feed_batch(self, feed_batch):
        """Split feed dicts into the int/float slot lists batch_predict expects."""
        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []
        int_shape = []
        float_shape = []

        first_keys = None
        for i, feed_i in enumerate(feed_batch):
            if i == 0:
                first_keys = list(feed_i)
            elif set(feed_i) != set(first_keys):
                raise ValueError(
                    "Feed names of sample {} differ from those of the first sample.".
                    format(i))
            int_slot = []
            float_slot = []
            # Every sample is walked in the first sample's key order, so slot j
            # of each sample lines up with int/float_feed_names[j].
            for key in first_keys:
                if key not in self.feed_names_:
                    raise ValueError("Wrong feed name: {}.".format(key))
                value = feed_i[key]
                is_array = isinstance(value, np.ndarray)
                if not is_array:
                    self.shape_check(feed_i, key)

                feed_type = self.feed_types_[key]
                if feed_type == int_type:
                    names, shapes, slot = int_feed_names, int_shape, int_slot
                elif feed_type == float_type:
                    names, shapes, slot = float_feed_names, float_shape, float_slot
                else:
                    continue

                if i == 0:
                    # Names and shapes are recorded once, from the first sample.
                    names.append(key)
                    shapes.append(
                        list(value.shape) if is_array else
                        self.feed_shapes_[key])
                slot.append(
                    np.reshape(value, (-1)).tolist() if is_array else value)
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        return (int_slot_batch, int_feed_names, int_shape, float_slot_batch,
                float_feed_names, float_shape)

    def _read_engine_result(self, result_batch, mi, fetch_names):
        """Convert engine ``mi``'s fetch outputs into numpy arrays (plus LoD info)."""
        # fetch_type -> (result getter, numpy dtype); other types are skipped.
        readers = {
            int_type: (result_batch.get_int64_by_name, 'int64'),
            float_type: (result_batch.get_float_by_name, 'float32'),
        }
        result_map = {}
        for name in fetch_names:
            reader = readers.get(self.fetch_names_to_type_[name])
            if reader is None:
                continue
            getter, dtype = reader
            array = np.array(getter(mi, name), dtype=dtype)
            array.shape = result_batch.get_shape(mi, name)
            result_map[name] = array
            if name in self.lod_tensor_set:
                result_map["{}.lod".format(name)] = result_batch.get_lod(
                    mi, name)
        return result_map

    def release(self):
        """Destroy the native predictor and drop the client handle.

        Does nothing if there is no handle (never loaded, or already released).
        """
        if self.client_handle_ is None:
            return
        self.client_handle_.destroy_predictor()
        self.client_handle_ = None