__init__.py 8.8 KB
Newer Older
G
guru4elephant 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

M
MRXLT 已提交
15 16
import paddle_serving_client
import os
17 18 19
from .proto import sdk_configure_pb2 as sdk
from .proto import general_model_config_pb2 as m_config
import google.protobuf.text_format
G
guru4elephant 已提交
20
import time
21
import sys
G
guru4elephant 已提交
22

G
guru4elephant 已提交
23 24 25
# Type codes compared against the feed_type / fetch_type values read from the
# model config: int_type selects the int slots and get_int64_by_name,
# float_type selects the float slots and get_float_by_name.
int_type = 0
float_type = 1

M
MRXLT 已提交
26

G
guru4elephant 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40
class SDKConfig(object):
    """Builds the ``sdk.SDKConf`` protobuf message describing one predictor.

    Usage: call ``set_server_endpoints`` with the server addresses, then
    ``gen_desc`` to obtain the populated message.
    """

    def __init__(self):
        self.sdk_desc = sdk.SDKConf()
        self.endpoints = []

    def set_server_endpoints(self, endpoints):
        """Remember the server endpoints to put into the cluster string.

        Args:
            endpoints: sequence of endpoint strings (presumably "host:port";
                confirm against the caller). The object is stored as-is.
        """
        self.endpoints = endpoints

    def gen_desc(self):
        """Populate and return ``self.sdk_desc``.

        NOTE: every call appends one more predictor entry to the same
        ``self.sdk_desc`` message, so call it once per ``SDKConfig`` instance.

        Returns:
            The ``sdk.SDKConf`` message held by this object.
        """
        predictor_desc = sdk.Predictor()
        predictor_desc.name = "general_model"
        predictor_desc.service_name = \
            "baidu.paddle_serving.predictor.general_model.GeneralModelService"
        predictor_desc.endpoint_router = "WeightedRandomRender"
        predictor_desc.weighted_random_render_conf.variant_weight_list = "100"

        variant_desc = sdk.VariantConf()
        variant_desc.tag = "var1"
        # The list:// naming scheme takes a comma-separated address list.
        # Joining with ":" produced an unparsable string as soon as more
        # than one endpoint was configured.
        variant_desc.naming_conf.cluster = "list://{}".format(",".join(
            self.endpoints))

        predictor_desc.variants.extend([variant_desc])

        self.sdk_desc.predictors.extend([predictor_desc])

        default_conf = self.sdk_desc.default_variant_conf
        default_conf.tag = "default"

        connection = default_conf.connection_conf
        connection.connect_timeout_ms = 2000
        connection.rpc_timeout_ms = 20000
        connection.connect_retry_count = 2
        connection.max_connection_per_host = 100
        connection.hedge_request_timeout_ms = -1
        connection.hedge_fetch_retry_count = 2
        connection.connection_type = "pooled"

        naming = default_conf.naming_conf
        naming.cluster_filter_strategy = "Default"
        naming.load_balance_strategy = "la"

        rpc = default_conf.rpc_parameter
        rpc.compress_type = 0
        rpc.package_size = 20
        rpc.protocol = "baidu_std"
        rpc.max_channel_per_request = 3

        return self.sdk_desc
G
guru4elephant 已提交
69

G
guru4elephant 已提交
70 71 72 73 74 75

class Client(object):
    def __init__(self):
        self.feed_names_ = []
        self.fetch_names_ = []
        self.client_handle_ = None
76
        self.result_handle_ = None
M
MRXLT 已提交
77
        self.feed_shapes_ = {}
G
guru4elephant 已提交
78
        self.feed_types_ = {}
G
guru4elephant 已提交
79
        self.feed_names_to_idx_ = {}
M
MRXLT 已提交
80
        self.rpath()
M
MRXLT 已提交
81
        self.pid = os.getpid()
M
MRXLT 已提交
82 83 84 85 86 87 88

    def rpath(self):
        """Point the rpath of ``serving_client.so`` at the package ``lib`` dir.

        Runs ``patchelf --set-rpath <pkg>/lib <pkg>/serving_client.so`` and
        waits for it to finish, so the patched library is in place before it
        is imported. The step stays best-effort: a missing ``patchelf`` only
        produces a warning on stderr.
        """
        # Local import keeps this block self-contained.
        import subprocess

        package_dir = os.path.dirname(paddle_serving_client.__file__)
        client_path = os.path.join(package_dir, 'serving_client.so')
        lib_path = os.path.join(package_dir, 'lib')
        try:
            # Argument list instead of a shell string: no quoting problems
            # with unusual paths, and call() blocks until patchelf exits.
            subprocess.call(['patchelf', '--set-rpath', lib_path, client_path])
        except OSError as err:
            sys.stderr.write(
                "warning: could not run patchelf on {}: {}\n".format(
                    client_path, err))

G
guru4elephant 已提交
89
    def load_client_config(self, path):
        """Load the client-side model config and set up the native handles.

        Parses the text-format ``GeneralModelConfig`` at ``path``, creates
        the ``PredictorRes`` / ``PredictorClient`` handles, and rebuilds all
        feed/fetch lookup tables from the config.

        Args:
            path: path of the text-format protobuf config file.
        """
        from .serving_client import PredictorClient
        from .serving_client import PredictorRes
        model_conf = m_config.GeneralModelConfig()
        # Context manager so the file handle is closed even if parsing fails.
        with open(path, 'r') as f:
            model_conf = google.protobuf.text_format.Merge(
                str(f.read()), model_conf)

        # load configuration here
        # get feed vars, fetch vars
        # get feed shapes, feed types
        # map feed names to index
        self.result_handle_ = PredictorRes()
        self.client_handle_ = PredictorClient()
        self.client_handle_.init(path)
        read_env_flags = ["profile_client", "profile_server"]
        self.client_handle_.init_gflags([sys.argv[
            0]] + ["--tryfromenv=" + ",".join(read_env_flags)])
        self.feed_names_ = [var.alias_name for var in model_conf.feed_var]
        self.fetch_names_ = [var.alias_name for var in model_conf.fetch_var]

        # Start from empty tables so that reloading a different config does
        # not leave entries from the previous one behind.
        self.feed_names_to_idx_ = {}
        self.feed_types_ = {}
        self.feed_shapes_ = {}
        self.fetch_names_to_type_ = {}
        self.fetch_names_to_idx_ = {}
        self.lod_tensor_set = set()
        self.feed_tensor_len = {}

        for i, var in enumerate(model_conf.feed_var):
            name = var.alias_name
            self.feed_names_to_idx_[name] = i
            self.feed_types_[name] = var.feed_type
            self.feed_shapes_[name] = var.shape

            if var.is_lod_tensor:
                self.lod_tensor_set.add(name)
            else:
                # Expected flat length = product of the configured dims;
                # shape_check() compares len(feed[name]) against it.
                counter = 1
                for dim in self.feed_shapes_[name]:
                    counter *= dim
                self.feed_tensor_len[name] = counter

        for i, var in enumerate(model_conf.fetch_var):
            self.fetch_names_to_idx_[var.alias_name] = i
            self.fetch_names_to_type_[var.alias_name] = var.fetch_type

G
guru4elephant 已提交
133
    def connect(self, endpoints):
        """Create the native predictor for the given server endpoints.

        Builds an SDK description from ``endpoints`` via ``SDKConfig``,
        prints it, and passes its serialized form to
        ``self.client_handle_.create_predictor_by_desc``.

        Args:
            endpoints: forwarded unchanged to
                ``SDKConfig.set_server_endpoints``.
        """
        config_builder = SDKConfig()
        config_builder.set_server_endpoints(endpoints)
        description = config_builder.gen_desc()
        # The generated description is echoed to stdout.
        print(description)
        self.client_handle_.create_predictor_by_desc(
            description.SerializeToString())
G
guru4elephant 已提交
143 144 145 146 147 148 149

    def get_feed_names(self):
        """Return the list of feed variable alias names.

        This is the internal list itself (not a copy); it is empty until
        load_client_config() fills it from the model config.
        """
        return self.feed_names_

    def get_fetch_names(self):
        """Return the list of fetch variable alias names.

        This is the internal list itself (not a copy); it is empty until
        load_client_config() fills it from the model config.
        """
        return self.fetch_names_

M
MRXLT 已提交
150 151 152 153
    def shape_check(self, feed, key):
        seq_shape = 1
        if key in self.lod_tensor_set:
            return
M
MRXLT 已提交
154
        if len(feed[key]) != self.feed_tensor_len[key]:
M
MRXLT 已提交
155 156 157
            raise SystemExit("The shape of feed tensor {} not match.".format(
                key))

G
guru4elephant 已提交
158
    def predict(self, feed={}, fetch=[]):
        """Run one prediction and return the requested fetch variables.

        Args:
            feed: mapping from feed name to value. Keys that are not
                configured feed names are ignored.
            fetch: iterable of fetch names. Names that are not configured
                fetch names are ignored.

        Returns:
            dict mapping each accepted fetch name to the first element of
            the result returned by ``get_int64_by_name`` or
            ``get_float_by_name``.

        Raises:
            SystemExit: from ``shape_check`` when a dense feed has the wrong
                length.

        The mutable defaults are kept for interface compatibility; they are
        only read, never mutated.
        """
        int_slot = []
        float_slot = []
        int_feed_names = []
        float_feed_names = []

        for key in feed:
            # Filter unknown names first. Validating first made an unknown
            # key fail inside shape_check instead of being skipped.
            if key not in self.feed_names_:
                continue
            self.shape_check(feed, key)
            feed_type = self.feed_types_[key]
            if feed_type == int_type:
                int_feed_names.append(key)
                int_slot.append(feed[key])
            elif feed_type == float_type:
                float_feed_names.append(key)
                float_slot.append(feed[key])

        fetch_names = [key for key in fetch if key in self.fetch_names_]

        # NOTE(review): the native call's return value was never inspected;
        # results are read back through self.result_handle_. Confirm whether
        # it carries an error status that should be checked.
        self.client_handle_.predict(float_slot, float_feed_names,
                                    int_slot, int_feed_names, fetch_names,
                                    self.result_handle_, self.pid)

        result_map = {}
        for name in fetch_names:
            fetch_type = self.fetch_names_to_type_[name]
            if fetch_type == int_type:
                result_map[name] = self.result_handle_.get_int64_by_name(name)[
                    0]
            elif fetch_type == float_type:
                result_map[name] = self.result_handle_.get_float_by_name(name)[
                    0]

        return result_map

G
guru4elephant 已提交
195
    def batch_predict(self, feed_batch=[], fetch=[]):
        """Run prediction for a batch of samples.

        The int/float feed name lists are taken from the first sample. Every
        sample's slots are then built in exactly that name order, so names
        and values stay aligned even if later samples list their keys in a
        different order.

        Args:
            feed_batch: iterable of mappings from feed name to value. Keys
                that are not configured feed names are ignored.
            fetch: iterable of fetch names. Names that are not configured
                fetch names are ignored.

        Returns:
            list with one dict per entry returned by the native
            ``batch_predict``, mapping each accepted fetch name to the value
            at the same position in that entry.

        Raises:
            ValueError: if a sample lacks a feed variable that the first
                sample provided.

        The mutable defaults are kept for interface compatibility; they are
        only read, never mutated.
        """
        int_slot_batch = []
        float_slot_batch = []
        int_feed_names = []
        float_feed_names = []

        for idx, feed in enumerate(feed_batch):
            if idx == 0:
                # The first sample decides which variables are sent and in
                # which order.
                for key in feed:
                    if key not in self.feed_names_:
                        continue
                    if self.feed_types_[key] == int_type:
                        int_feed_names.append(key)
                    elif self.feed_types_[key] == float_type:
                        float_feed_names.append(key)
            try:
                int_slot = [feed[name] for name in int_feed_names]
                float_slot = [feed[name] for name in float_feed_names]
            except KeyError as err:
                raise ValueError(
                    "sample {} of feed_batch is missing feed variable {}".
                    format(idx, err))
            int_slot_batch.append(int_slot)
            float_slot_batch.append(float_slot)

        fetch_names = [key for key in fetch if key in self.fetch_names_]

        result_batch = self.client_handle_.batch_predict(
            float_slot_batch, float_feed_names, int_slot_batch, int_feed_names,
            fetch_names)

        result_map_batch = []
        for result in result_batch:
            # Results are positional: entry i belongs to fetch_names[i].
            result_map_batch.append(
                {name: result[i] for i, name in enumerate(fetch_names)})

        return result_map_batch
236 237 238

    def release(self):
        self.client_handle_.destroy_predictor()
G
guru4elephant 已提交
239
        self.client_handle_ = None